#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

import random
import numpy as np
import fast_histogram
import torch


########################################################################
# numpy implementation of range with history
########################################################################

class RangeEstimator:
    """Track a running (min, max) range for a stream of tensors.

    Every call folds one tensor into a history that is kept separately per
    device string and returns the updated range. Supported modes:

    * ``'minmax_ema'`` - exponential moving average (EMA) of the raw min/max
      of each tensor. This path is also taken whenever
      ``range_shrink_percentile == 0.0``, for backward compatibility.
    * ``'histogram_minmax_ema'`` - approximate: a histogram is built per
      tensor, percentile-shrunk min/max are read from it, and the EMA is
      applied to those two values.
    * ``'histogram_ema_minmax'`` - most accurate: the EMA is applied to the
      histogram itself and min/max are read from the merged histogram.

    Args:
        range_shrink_percentile: amount of histogram mass, on a 0-100 scale,
            that is cut away from each tail when reading min/max from a
            histogram.
        range_update_factor_min: lower bound for the EMA update factor. The
            factor is ``1 / (number_of_calls_so_far + 1)`` clamped from below
            by this value.
        mode: one of the mode strings listed above.
        fast_mode: when True, 4-D inputs whose last two dimensions are larger
            than ``2 * fast_stride`` are subsampled with stride
            ``fast_stride`` (random start offset) before any statistics are
            computed.
    """

    def __init__(self, range_shrink_percentile=0.01, range_update_factor_min=0.001,
                 mode='histogram_ema_minmax', fast_mode=True):
        self.range_shrink_percentile = range_shrink_percentile
        self.range_update_factor_min = range_update_factor_min
        self.fast_mode = fast_mode
        self.mode = mode
        self.histogram_bins = 1024
        self.histogram_scale = 100.0  # cumulative histogram is expressed in percent
        self.eps = 1e-16
        self.fast_stride = 2
        # both dicts are keyed by str(device)
        self.history = dict()
        self.counter = dict()

    def __call__(self, tensor):
        """Fold ``tensor`` into the history and return the updated ``(mn, mx)``.

        Args:
            tensor: a ``torch.Tensor`` or a numpy array. Objects without a
                ``device`` attribute are tracked under the key ``'cpu'``.

        Returns:
            Tuple ``(mn, mx)`` with the current range estimate.

        Raises:
            AssertionError: if ``self.mode`` is not a supported mode.
        """
        device_str = str(getattr(tensor, 'device', 'cpu'))
        history = self.history.get(device_str)
        counter = self.counter.get(device_str, 0)

        tensor = self._downsample(tensor)
        if isinstance(tensor, torch.Tensor):
            tensor = tensor.detach().cpu().numpy()

        # plain running average at first, EMA once 1/(counter+1) drops below the minimum
        update_factor = max(1.0 / (counter + 1), self.range_update_factor_min)

        # using minmax_ema method when range_shrink_percentile = 0.0 for backward compatibility
        if self.mode == 'minmax_ema' or self.range_shrink_percentile == 0.0:
            mn, mx = self._tensor_minmax(tensor)
            mn, mx = self._ema_minmax(history, mn, mx, update_factor)
            hist_merged = dict(min=mn, max=mx)
        elif self.mode == 'histogram_minmax_ema':
            # approximate - histogram is done, but ema is on min/max values
            mn, mx = self._search_histogram_minmax(self._fast_histogram(tensor))
            mn, mx = self._ema_minmax(history, mn, mx, update_factor)
            hist_merged = dict(min=mn, max=mx)
        elif self.mode == 'histogram_ema_minmax':
            # most accurate - ema is performed on histogram
            hist = self._fast_histogram(tensor)
            if history is None:
                hist_merged = hist
            else:
                hist_list = [history, hist]
                # the merged histogram spans the union of both ranges
                minmax_range = (min(h['edges'][0] for h in hist_list),
                                max(h['edges'][-1] for h in hist_list))
                hist_merged = self._merge_histograms_to_range(
                    hist_list, minmax_range=minmax_range,
                    scale_factors=[1 - update_factor, update_factor])
            mn, mx = self._search_histogram_minmax(hist_merged)

            # avoid the counts from overflowing
            if np.max(hist_merged['counts']) > 1e12:
                hist_merged['counts'] = hist_merged['counts'] / 2

            # TODO: the merged range can only grow from call to call. An
            # earlier attempt to clip edges lying far outside (mn, mx) did
            # not work and was removed.
        else:
            # explicit raise (rather than `assert False`) so that the check
            # is not stripped when Python runs with -O
            raise AssertionError(f'invalid value for mode: {self.mode}')

        self.history[device_str] = hist_merged
        self.counter[device_str] = counter + 1
        return mn, mx

    def _downsample(self, tensor):
        """Return a strided view of a large 4-D ``tensor`` when fast_mode is on.

        Inputs that are not 4-D, or whose last two dimensions are not larger
        than ``2 * fast_stride``, are returned unchanged.
        """
        if not self.fast_mode:
            return tensor
        shape = tuple(getattr(tensor, 'shape', ()))
        min_size = self.fast_stride * 2
        if len(shape) != 4 or shape[2] <= min_size or shape[3] <= min_size:
            return tensor
        r_start = random.randint(0, self.fast_stride - 1)
        c_start = random.randint(0, self.fast_stride - 1)
        return tensor[..., r_start::self.fast_stride, c_start::self.fast_stride]

    def _ema_minmax(self, history, mn, mx, update_factor):
        """Blend ``(mn, mx)`` with the min/max stored in ``history``.

        ``history`` is the per-device entry (a dict with ``'min'`` and
        ``'max'``) or None on the first call, in which case the inputs are
        returned unchanged.
        """
        if history is None:
            return mn, mx
        keep = 1 - update_factor
        return (history['min'] * keep + mn * update_factor,
                history['max'] * keep + mx * update_factor)

    def _merge_histograms_to_range(self, hist_list, minmax_range, scale_factors=None):
        """Accumulate several histograms onto one common set of bin edges.

        Args:
            hist_list: histograms (dicts with ``'counts'`` and ``'edges'``).
            minmax_range: ``(low, high)`` range of the merged histogram.
            scale_factors: optional per-histogram weights; 1.0 when omitted.

        Returns:
            A histogram dict with ``self.histogram_bins`` bins over
            ``minmax_range``.
        """
        assert isinstance(minmax_range, (list, tuple)) and len(minmax_range) == 2, 'range must have size 2'
        if scale_factors is None:
            scale_factors = [1.0] * len(hist_list)
        hist_merged = dict(
            counts=np.zeros((self.histogram_bins,)),
            edges=np.linspace(minmax_range[0], minmax_range[1], num=self.histogram_bins + 1))
        for hist, scale_factor in zip(hist_list, scale_factors):
            hist_merged = self._merge_histograms(hist, hist_merged,
                                                 scale_factors=(scale_factor, 1.0))
        return hist_merged

    def _merge_histograms(self, histprev, histnew, scale_factors=(1.0, 1.0)):
        """Merge ``histprev`` onto the bin edges of ``histnew``.

        ``histprev`` is re-binned to ``histnew['edges']`` and the two count
        arrays are combined as
        ``prev * scale_factors[0] + new * scale_factors[1]``. This is NOT an
        in-place operation: the counts are returned in a new dict. If either
        histogram is None the other one is returned as is.
        """
        if histnew is None:
            return histprev
        if histprev is None:
            return histnew
        histprev_interp = self._interpolate_histograms(histprev, histnew['edges'])
        merged_counts = (histprev_interp['counts'] * scale_factors[0]
                         + histnew['counts'] * scale_factors[1])
        return dict(counts=merged_counts, edges=histnew['edges'])

    def _interpolate_histograms(self, hist, hist_edges_new):
        """Re-bin ``hist`` onto ``hist_edges_new``.

        The cumulative counts are linearly interpolated at the new edges and
        differenced back into per-bin counts.
        """
        hist_counts_cum = np.hstack([0, np.cumsum(hist['counts'])])
        hist_counts_cum_new = np.interp(hist_edges_new, hist['edges'], hist_counts_cum)
        return dict(counts=np.diff(hist_counts_cum_new), edges=hist_edges_new)

    def _tensor_minmax(self, tensor):
        """Return ``(min, max)`` over all elements of ``tensor``."""
        return np.min(tensor), np.max(tensor)

    def _numpy_histogram(self, tensor):
        """Histogram of ``tensor`` with ``self.histogram_bins`` bins via numpy."""
        counts, edges = np.histogram(tensor, bins=self.histogram_bins)
        return dict(counts=counts, edges=edges)

    def _fast_histogram(self, tensor, minmax_range=None):
        """Histogram of ``tensor`` computed with the fast_histogram package.

        Same functionality as ``self._numpy_histogram(tensor)`` but faster.

        Args:
            tensor: numpy array (other array-likes are converted).
            minmax_range: optional ``(low, high)`` range of the histogram;
                defaults to the min/max of ``tensor``.

        Returns:
            Dict with ``'counts'`` and ``'edges'`` (``histogram_bins + 1``
            evenly spaced edges).
        """
        data = tensor if isinstance(tensor, np.ndarray) else np.asarray(tensor)
        if minmax_range is None:
            low, high = data.min(), data.max()
        else:
            low, high = minmax_range
        if low == high:
            # constant input: widen the zero-width range the same way
            # np.histogram does, so that the bins have a non-zero width
            low, high = low - 0.5, high + 0.5
        hist_counts = fast_histogram.histogram1d(data, range=(low, high), bins=self.histogram_bins)
        hist_edges = np.linspace(low, high, num=self.histogram_bins + 1)
        return dict(counts=hist_counts, edges=hist_edges)

    def _search_histogram_minmax(self, hist):
        """Read the percentile-shrunk ``(mn, mx)`` from a histogram.

        The counts are normalized to a 0-100 cumulative scale. ``mn`` is the
        left edge of the first bin whose cumulative value reaches
        ``range_shrink_percentile``; ``mx`` is the left edge of the last bin
        whose cumulative value does not exceed
        ``100 - range_shrink_percentile``. When no bin qualifies (empty
        histogram, or nearly all mass in a single bin) the search falls back
        so that it never raises and never returns ``mx < mn``.
        """
        hist_min = hist['edges'][0]
        hist_max = hist['edges'][-1]
        hist_delta = (hist_max - hist_min) / self.histogram_bins

        hist_sum = np.sum(hist['counts'])
        hist_sum = self.eps if hist_sum == 0 else hist_sum
        hist_cum = np.cumsum(hist['counts'] * self.histogram_scale / hist_sum)

        lower_hits = np.where(hist_cum >= self.range_shrink_percentile)[0]
        upper_hits = np.where(hist_cum <= (self.histogram_scale - self.range_shrink_percentile))[0]
        mn_idx = lower_hits[0] if lower_hits.size else 0
        mx_idx = upper_hits[-1] if upper_hits.size else mn_idx
        # mass concentrated in one bin would otherwise give an inverted range
        mx_idx = max(mx_idx, mn_idx)

        mn = hist_min + mn_idx * hist_delta
        mx = hist_min + mx_idx * hist_delta
        return mn, mx

